In [722]:
import numpy as np
import scipy.linalg
import scipy.weave
import sklearn
from sklearn.base import BaseEstimator
class RobustPCA(BaseEstimator):
'''Robust PCA
Candès, Emmanuel J., et al. "Robust principal component analysis?." Journal of the ACM (JACM) 58.3 (2011): 11.
http://arxiv.org/abs/0912.3599
'''
def __nuclear_prox(self, A, r=1.0):
'''Proximal operator for scaled nuclear norm:
Y* <- argmin_Y r * ||Y||_* + 1/2 * ||Y - A||_F^2
Arguments:
A -- (ndarray) input matrix
r -- (float>0) scaling factor
Returns:
Y -- (ndarray) if A = USV', then Y = UTV'
where T = max(S - r, 0)
'''
U, S, V = scipy.linalg.svd(A, full_matrices=False)
T = np.maximum(S - r, 0.0)
Y = (U * T).dot(V)
return Y
def __l1_prox(self, A, r=1.0):
'''Proximal operator for entry-wise matrix l1 norm:
Y* <- argmin_Y r * ||Y||_1 + 1/2 * ||Y - A||_F^2
Arguments:
A -- (ndarray) input matrix
r -- (float>0) scaling factor
Returns:
Y -- (ndarray) entry-wise soft-thresholding of A: Y = sign(A) * max(|A| - r, 0)
'''
Y = np.zeros_like(A)
numel = A.size
shrinkage = r"""
for (int i = 0; i < numel; i++) {
Y[i] = 0;
if (A[i] - r > 0) {
Y[i] = A[i] - r;
} else if (A[i] + r <= 0) {
Y[i] = A[i] + r;
}
}
"""
scipy.weave.inline(shrinkage, ['numel', 'A', 'r', 'Y'])
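# Note: the C loop above is just the entry-wise soft-thresholding operator;
# an equivalent pure-numpy form (no weave needed) would be
#   Y = np.sign(A) * np.maximum(np.abs(A) - r, 0.0)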
return Y
def __cost(self, Y, Z):
'''Get the cost of an RPCA solution.
Arguments:
Y -- (ndarray) the low-rank component
Z -- (ndarray) the sparse component
Returns:
total, nuclear_norm, l1_norm -- (list of floats)
'''
nuclear_norm = scipy.linalg.svd(Y,
full_matrices=False,
compute_uv=False).sum()
l1_norm = np.abs(Z).sum()
return nuclear_norm + self.alpha_ * l1_norm, nuclear_norm, l1_norm
def __init__(self, alpha=None, max_iter=100, verbose=False):
'''
Arguments:
alpha -- (float > 0) weight between low-rank and noise term
If left as None, alpha will be automatically set to
1.0 / sqrt(max(X.shape))
max_iter -- (int > 0) maximum number of iterations
'''
self.alpha = alpha
self.max_iter = max_iter
self.verbose = verbose
def fit(self, X):
'''Fit the robust PCA model to a matrix X'''
self.fit_transform(X)
return self
def fit_transform(self, X):
# Some magic numbers for dynamic augmenting penalties in ADMM.
# Changing these shouldn't affect correctness, only convergence rate.
RHO_MIN = 1e0
RHO_MAX = 1e5
MAX_RATIO = 2e0
SCALE_FACTOR = 1.5e0
ABS_TOL = 1e-4
REL_TOL = 1e-3
# update rules:
# Y+ <- nuclear_prox(X - Z - W, 1/rho)
# Z+ <- l1_prox(X - Y - W, alpha/rho)
# W+ <- W + Y + Z - X
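# These are the scaled-form ADMM updates for the RPCA objective
#   minimize ||Y||_* + alpha * ||Z||_1   subject to   Y + Z = X,
# where W is the scaled dual variable for the equality constraint and
# rho is the augmented-Lagrangian penalty, adapted between RHO_MIN and RHO_MAX below.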
# Initialize
rho = RHO_MIN
# Scale the data to a workable range
X = X.astype(np.float)
Xmin = np.min(X)
rescale = max(1e-8, np.max(X - Xmin))
Xt = (X - Xmin) / rescale
Y = Xt.copy()
Z = np.zeros_like(Xt)
W = np.zeros_like(Xt)
norm_X = scipy.linalg.norm(Xt)
if self.alpha is None:
self.alpha_ = max(Xt.shape)**(-0.5)
else:
self.alpha_ = self.alpha
m = X.size
_DIAG = {
'err_primal': [],
'err_dual': [],
'eps_primal': [],
'eps_dual': [],
'rho': []
}
for t in range(self.max_iter):
Y = self.__nuclear_prox(Xt - Z - W, 1.0/rho)
Z_old = Z.copy()
Z = self.__l1_prox(Xt - Y - W, self.alpha_ / rho)
residual_pri = Y + Z - Xt
residual_dual = Z - Z_old
res_norm_pri = scipy.linalg.norm(residual_pri)
res_norm_dual = rho * scipy.linalg.norm(residual_dual)
W = W + residual_pri
eps_pri = np.sqrt(m) * ABS_TOL + REL_TOL * max(scipy.linalg.norm(Y), scipy.linalg.norm(Z), norm_X)
eps_dual = np.sqrt(m) * ABS_TOL + REL_TOL * scipy.linalg.norm(W)
_DIAG['eps_primal'].append(eps_pri)
_DIAG['eps_dual' ].append(eps_dual)
_DIAG['err_primal'].append(res_norm_pri)
_DIAG['err_dual' ].append(res_norm_dual)
_DIAG['rho' ].append(rho)
if res_norm_pri <= eps_pri and res_norm_dual <= eps_dual:
break
if res_norm_pri > MAX_RATIO * res_norm_dual and rho * SCALE_FACTOR <= RHO_MAX:
rho = rho * SCALE_FACTOR
W = W / SCALE_FACTOR
elif res_norm_dual > MAX_RATIO * res_norm_pri and rho / SCALE_FACTOR >= RHO_MIN:
rho = rho / SCALE_FACTOR
W = W * SCALE_FACTOR
if self.verbose:
if t < self.max_iter - 1:
print 'Converged in %d steps' % t
else:
print 'Reached maximum iterations'
# Scale back up to the original data scale
Z = Z * rescale
self.embedding_ = X - Z
_DIAG['cost'] = self.__cost(self.embedding_, Z)
self.diagnostics_ = _DIAG
return self.embedding_
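In [ ]:
# Hypothetical sanity check (not part of the original analysis): build a synthetic
# low-rank + sparse matrix and verify that fit_transform recovers the low-rank part.
# The names below (L0, S0, M_test) are illustrative only.
rng = np.random.RandomState(0)
L0 = rng.randn(60, 5).dot(rng.randn(5, 80))  # rank-5 component
S0 = 10 * rng.binomial(1, 0.05, size=L0.shape) * rng.randn(*L0.shape)  # sparse spikes
M_test = RobustPCA(max_iter=200)
L_hat = M_test.fit_transform(L0 + S0)
print 'Relative recovery error: %.3f' % (scipy.linalg.norm(L_hat - L0) / scipy.linalg.norm(L0))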
In [723]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn
seaborn.set(style='white')
%matplotlib inline
In [724]:
import sklearn.datasets
In [725]:
# Load in the example images from sklearn
data = sklearn.datasets.load_sample_images()
In [726]:
# Convert each image from RGB to grayscale
X = [np.mean(D, axis=-1) for D in data.images]
In [736]:
# Pull out the first image + laplace noise
dx = X[0] + np.random.laplace(scale=5, size=X[0].shape)
In [737]:
plt.subplot(121)
plt.imshow(X[0], interpolation='none')
plt.axis('off')
plt.title('Original')
plt.subplot(122)
plt.imshow(dx, interpolation='none')
plt.axis('off')
plt.title('Noisy')
plt.tight_layout()
In [738]:
# Build a model object
M = RobustPCA(verbose=True)
In [739]:
M.fit(dx)
Out[739]:
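In [ ]:
# Illustrative follow-up cell: peek at the stored diagnostics from the fit above,
# i.e. the total objective value and its nuclear-norm / l1 components.
print M.diagnostics_['cost']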
In [740]:
plt.semilogy(M.diagnostics_['err_primal'], label='Primal error')
plt.semilogy(M.diagnostics_['err_dual'], label='Dual error')
plt.xlabel('Iterations')
plt.legend()
plt.tight_layout()
In [741]:
plt.plot(M.diagnostics_['rho'], label=r'$\rho$')
plt.legend()
plt.title('Augmenting term factor')
plt.tight_layout()
In [742]:
# Helper function to pull out the normalized spectrum of a matrix
def spectrum(X, norm=True):
v = scipy.linalg.svd(X, compute_uv=False)
if norm:
v = v / v.max()
return v
In [743]:
# How do the spectra compare?
plt.bar(np.arange(10), spectrum(dx)[:10], width=0.45, label='Input X')
plt.bar(np.arange(10) + 0.5, spectrum(M.embedding_)[:10], width=0.45, color='r', label='Low-rank approximation')
plt.xticks(0.45 + np.arange(10), range(1,11))
plt.xlabel(r'$i$')
plt.ylabel(r'$\sigma_i / \sigma_1$')
plt.title('Normalized singular value distribution')
plt.legend()
plt.tight_layout()
In [744]:
# How do the images look?
plt.subplot(131)
plt.imshow(dx, interpolation='none')
plt.title('Input')
plt.axis('off')
plt.subplot(132)
plt.imshow(M.embedding_, interpolation='none')
plt.title('Low-rank approximation')
plt.axis('off')
plt.subplot(133)
plt.imshow(dx - M.embedding_, interpolation='none')
plt.title('Residual')
plt.axis('off')
plt.tight_layout()